# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import tempfile
import textwrap
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import gym
import numpy as np
import hydra

try:
    import huggingface_hub  # noqa
except ImportError as e:
    # Re-raise with installation instructions, keeping the original error as
    # the cause so the underlying import failure is still visible.
    raise ImportError(
        "You need to install huggingface_hub to upload and "
        "download models from the Hub. "
        "See https://pypi.org/project/huggingface-hub/ for installation."
    ) from e

from huggingface_hub import (
    create_repo,
    delete_folder,
    hf_hub_download,
    list_repo_files,
    upload_file,
    upload_folder,
)
from huggingface_hub.repocard import metadata_eval_result, metadata_save
from omegaconf import OmegaConf

from ..models import Model
from ..planning import Agent, sac_wrapper
from ..third_party.pytorch_sac import VideoRecorder


def get_model_id(
    org_name: str,
    env_id: str,
    tag: Optional[str] = None,
    algo_name: Optional[str] = None,
) -> str:
    """
    Construct a name for the model on the hub.

    The result has the form ``"{org_name}/{algo_name}-{env_id}_{tag}"``,
    where any ``/`` in ``env_id`` is replaced by ``-``. The ``{algo_name}-``
    prefix is omitted when no algorithm name is given, and the ``_{tag}``
    suffix is omitted when no tag is given.

    Args:
        org_name (str): HF organization to upload the model to.
        env_id (str): environment name for the data (often from Gym
            registration).
        tag (str, optional): Tag for the model to differentiate from
            models in the same environment-algo pairing.
        algo_name (str, optional): Name for the algorithm (if one was
            used to gather the data and train the model).

    Returns:
        name: string combining the various naming components in a
            deterministic manner (for ease of use on the Hub)
    """
    # A "/" inside the env id (e.g. "ALE/Pong-v5") would otherwise be read as
    # an extra path level of the repository id.
    repo_name = env_id.replace("/", "-")
    # Skip missing/empty components instead of emitting "None-..." or a
    # dangling "_".
    if algo_name:
        repo_name = f"{algo_name}-{repo_name}"
    if tag:
        repo_name = f"{repo_name}_{tag}"
    return f"{org_name}/{repo_name}"


def generate_metadata(
    model_name: str,
    env_id: str,
    mean_reward: Optional[float] = None,
    std_reward: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Define and generate the tags for the model card.

    Args:
        model_name (str): name of the model.
        env_id (str): name of the environment.
        mean_reward (float, optional): mean reward of the agent. When given,
            an evaluation result is added to the metadata.
        std_reward (float, optional): standard deviation of the mean
            reward of the agent. Only used together with ``mean_reward``.

    Returns:
        dict: metadata with ``library_name`` and ``tags``, merged with the
            output of ``metadata_eval_result`` when ``mean_reward`` is given.
    """
    metadata: Dict[str, Any] = {
        "library_name": "mbrl-lib",
        "tags": [
            env_id,
            "deep-reinforcement-learning",
            "reinforcement-learning",
            "mbrl-lib",
        ],
    }

    # Compare against None explicitly: a mean or std of exactly 0.0 is a
    # legitimate result and must still be reported.
    if mean_reward is not None:
        metrics_value = f"{mean_reward:.2f}"
        if std_reward is not None:
            metrics_value += f" +/- {std_reward:.2f}"

        # Add metrics
        eval_dict = metadata_eval_result(
            model_pretty_name=model_name,
            task_pretty_name="reinforcement-learning",
            task_id="reinforcement-learning",
            metrics_pretty_name="mean_reward",
            metrics_id="mean_reward",
            metrics_value=metrics_value,
            dataset_pretty_name=env_id,
            dataset_id=env_id,
        )
        metadata.update(eval_dict)

    return metadata


def _generate_model_card(
    model_name: str,
    env_id: str,
    mean_reward: Optional[float] = None,
    std_reward: Optional[float] = None,
) -> Tuple[str, Dict[str, Any]]:
    """
    Build the README text and the metadata for the Hub model card.

    Args:
        model_name (str): name of the model.
        env_id (str): name of the environment.
        mean_reward (float, optional): mean reward of the agent.
        std_reward (float, optional): standard deviation of the mean reward
            of the agent.

    Returns:
        tuple: ``(model_card, metadata)`` where ``model_card`` is the
            dedented markdown body and ``metadata`` is the dict produced by
            ``generate_metadata``.
    """
    card_metadata = generate_metadata(model_name, env_id, mean_reward, std_reward)

    # The literal below is indented for readability; dedent strips the
    # common leading whitespace before it is written to the README.
    card_template = f"""\
            # **{model_name}** Agent playing **{env_id}**
            This is a trained model of a **{model_name}** agent
            playing **{env_id}**
            using [MBRL-Lib](https://github.com/facebookresearch/mbrl-lib).

            ## Usage (with MBRL-Lib)
            TODO: Add your code
            ```python
            from mbrl import ...
            ...
            ```
            """
    return textwrap.dedent(card_template), card_metadata


def _save_model_card(
    local_path: Path, generated_model_card: str, metadata: Dict[str, Any]
):
    """
    Write ``README.md`` into the repository directory and attach metadata.

    If a README already exists in ``local_path``, its text is read back and
    written out again. Otherwise ``generated_model_card`` becomes the README
    body. ``metadata`` is then saved into the README via ``metadata_save``.

    Args:
        local_path (Path): repository directory
        generated_model_card (str): model card generated by
            _generate_model_card()
        metadata (dict): metadata
    """
    readme_path = local_path / "README.md"
    readme_text = (
        readme_path.read_text(encoding="utf8")
        if readme_path.exists()
        else generated_model_card
    )
    readme_path.write_text(readme_text, encoding="utf-8")

    # Save our metrics to Readme metadata
    metadata_save(readme_path, metadata)


def _add_logdir(local_path: Path, logdir: Path, repo_id: str):
    """
    Adds a logdir to the repository.

    Args:
        local_path (Path): repository directory.
        logdir (path): logdir directory.
        repo_id (str): id of the model repository from the Hugging Face Hub.
    """
    if logdir.exists() and logdir.is_dir():
        # Add the logdir to the repository under new dir called logs
        repo_logdir = local_path / "logs"

        files = list_repo_files(repo_id=repo_id)
        del_logs = any(["logs/" in f for f in files])
        if del_logs:
            print(
                "Found existing logs on Hub, deleting and replacing with \
                new logs."
            )
            delete_folder("logs", repo_id=repo_id)
        # Delete current logs if they exist
        if repo_logdir.exists():
            shutil.rmtree(repo_logdir)

        # Copy logdir into repo logdir
        shutil.copytree(logdir, repo_logdir)


def _play_episode(
    env: gym.Env,
    agent: Agent,
    video_recorder: Optional[VideoRecorder],
    record: bool,
) -> float:
    """Roll out a single episode and return the sum of its rewards."""
    observation = env.reset()
    if video_recorder:
        video_recorder.init(enabled=record)
    total = 0.0
    while True:
        observation, reward, done, _ = env.step(agent.act(observation))
        if video_recorder:
            video_recorder.record(env)
        total += reward
        if done:
            break
    if video_recorder:
        video_recorder.save("replay.mp4")
    return total


def evaluate(
    env: gym.Env,
    agent: Agent,
    num_episodes: int = 10,
    video_recorder: Optional[VideoRecorder] = None,
) -> np.ndarray:
    """
    Run ``agent`` in ``env`` for several episodes and collect their returns.

    Args:
        env (gym.Env): environment to evaluate the agent.
        agent (Agent): agent to be evaluated.
        num_episodes (int): number of episodes.
        video_recorder (VideoRecorder, optional): recorder to save the video
            of the first episode. It is initialized before every episode
            (enabled only for the first one), fed every step, and asked to
            save ``replay.mp4`` after every episode.

    Returns:
        np.ndarray: one summed reward per episode, in episode order.
    """
    episode_returns = [
        _play_episode(env, agent, video_recorder, record=(idx == 0))
        for idx in range(num_episodes)
    ]
    return np.array(episode_returns)


def save_video(
    agent: Agent, eval_env: gym.Env, save_dir: Path
) -> Tuple[Optional[float], Optional[float]]:
    """
    Evaluate the agent over 10 episodes and store a replay of the first one.

    The replay (``replay.mp4``) is recorded into a temporary directory and
    then copied into ``save_dir``. Any exception is printed, not raised, so
    the packaging pipeline can continue without a replay.

    Args:
        agent (Agent): agent to evaluate.
        eval_env (gym.Env): environment to evaluate the agent.
        save_dir (Path): directory to save the gameplay video in.

    Returns:
        tuple: ``(mean, std)`` of the episode returns, or ``(None, None)`` if
            the evaluation did not complete.
    """
    reward_stats: Tuple[Optional[float], Optional[float]] = (None, None)
    with tempfile.TemporaryDirectory() as tmp_video_root:
        try:
            recorder = VideoRecorder(tmp_video_root)
            episode_returns = evaluate(eval_env, agent, 10, recorder)
            reward_stats = (episode_returns.mean(), episode_returns.std())
            replay_path = Path(os.path.join(recorder.save_dir, "replay.mp4"))
            _copy_file(replay_path, save_dir)
        except Exception as err:
            print(str(err))
            # Tell the user the replay is missing but packaging goes on.
            print(
                "We are unable to generate a replay of your agent, "
                "the package_to_hub process continues"
            )
            print(
                "Please open an issue at "
                "https://github.com/facebookresearch/mbrl-lib/issues"
            )
    return reward_stats


def package_to_hub(
    repo_id: str,
    model: Model,
    env_id: str,
    commit_message: str,
    agent: Optional[Agent] = None,
    eval_env: Optional[gym.Env] = None,
    cfg: Optional[OmegaConf] = None,
    model_name: Optional[str] = None,
    token: Optional[str] = None,
    logs: Optional[Path] = None,
    mean_reward: Optional[float] = None,
    std_reward: Optional[float] = None,
):
    """
    Upload a model to Hugging Face Hub by creating a new repository.
    This method does the complete pipeline:
    - It saves the model (and, when given, the config and the agent) into a
      temporary directory
    - When both an agent and `eval_env` are given, it evaluates the agent and
      records a replay video
    - It generates the model card
    - It pushes everything to the hub

    Args:
        repo_id (str): id of the model repository from the Hugging Face Hub.
        model (Model): trained model.
        env_id (str): name of the environment.
        commit_message (str): commit message.
        agent (Agent): optional agent included in model.
        eval_env (gym.Env): environment used to evaluate the agent. If not
            given, no evaluation or replay video is produced.
        cfg (OmegaConf, dict, optional): Hydra config used to create the model.
        model_name (str, optional): name of the architecture of your model.
            Defaults to `model.__class__.__name__`, followed by
            ` w/ <agent class name>` when an agent is given.
        token (str, optional): optional token for HF API.
        logs (Path, optional): directory on local machine of tensorboard logs
            you'd like to upload.
        mean_reward (float, optional): mean reward obtained by the agent in
            the environment. Replaced by the evaluated value when the agent
            is successfully evaluated on `eval_env`.
        std_reward (float, optional): std of reward obtained by the agent in
            the environment. Replaced together with `mean_reward`.

    Returns:
        The value returned by `upload_folder` for the pushed commit.
    """

    msg = (
        "This function will save your agent, "
        "create a model card and push everything to the hub. "
        "This is a work in progress: if you encounter a bug, "
        "please open an issue."
    )
    print(msg)

    create_repo(
        repo_id=repo_id,
        token=token,
        private=False,
        exist_ok=True,
    )

    if model_name is None:
        model_name = str(model.__class__.__name__)
        if agent is not None:
            model_name += f" w/ {agent.__class__.__name__}"

    with tempfile.TemporaryDirectory() as save_dir_str:
        save_dir = Path(save_dir_str)

        model_dir = save_dir / "model"
        model_dir.mkdir(parents=True, exist_ok=True)
        model.save(model_dir)

        if cfg:
            with open(save_dir / "config.yaml", "w", encoding="utf-8") as outfile:
                OmegaConf.save(cfg, outfile, resolve=True)

        if agent:
            # TODO: Add general agent saving functionality once
            # there is a standard save/load mechanism.
            agent_dir = save_dir / "agent"
            agent_dir.mkdir(parents=True, exist_ok=True)
            if isinstance(agent, sac_wrapper.SACAgent):
                ckpt_path = agent_dir / "checkpoint.pth"
                agent.sac_agent.save_checkpoint(ckpt_path=ckpt_path)
            else:
                print(f"Agent saving behavior not implemented for {type(agent)}.")

            if eval_env is not None:
                eval_mean, eval_std = save_video(agent, eval_env, save_dir)
                # save_video returns (None, None) when evaluation fails; keep
                # the caller-supplied statistics in that case.
                if eval_mean is not None:
                    mean_reward, std_reward = eval_mean, eval_std
            else:
                print("No eval_env provided; skipping agent evaluation and replay.")

        generated_model_card, metadata = _generate_model_card(
            model_name, env_id, mean_reward, std_reward
        )
        _save_model_card(save_dir, generated_model_card, metadata)

        if logs:
            _add_logdir(save_dir, Path(logs), repo_id)

        print(f"Pushing repo {repo_id} to the Hugging Face Hub")

        repo_url = upload_folder(
            repo_id=repo_id,
            folder_path=save_dir,
            path_in_repo="",
            commit_message=commit_message,
            token=token,
        )

        print(
            f"Your model is pushed to the Hub. You can view your model "
            f"here: {repo_url}"
        )
    return repo_url


def _copy_file(filepath: Path, dst_directory: Path) -> Path:
    """
    Copy the file to the correct directory

    Args:
        filepath (Path): path of the file.
        dst_directory (Path): destination directory.
    """
    dst = dst_directory / filepath.name
    shutil.copy(str(filepath), str(dst))
    return dst


def push_to_hub(
    repo_id: str,
    filename: Union[str, Path],
    commit_message: str,
    upload_path: str = "",
    token: Optional[str] = None,
    verbose: bool = True,
) -> str:
    """
    Upload a single file to a Hugging Face Hub repository.

    The repository is created (as public) if it does not exist yet.
    Otherwise the file is added to, or replaces the file at the same path in,
    the existing repository.

    Args:
        repo_id (str): repo_id id of the model repository from the
            Hugging Face Hub.
        filename (str, Path): name of the model zip or mp4 file from the
            repository.
        commit_message (str): commit message.
        upload_path (str): directory inside the repository to put the file
            in; "" (default) means the repository root.
        token (str, optional): optional token for HF API.
        verbose (bool): whether to print progress messages and the
            resulting URL.

    Returns:
        The value returned by `upload_file` for the uploaded file.
    """
    import posixpath

    filename = Path(filename).resolve()
    create_repo(
        repo_id=repo_id,
        token=token,
        private=False,
        exist_ok=True,
    )

    # Paths inside a Hub repo always use "/" separators; os.path.join would
    # produce "\\" on Windows.
    path_in_repo = posixpath.join(upload_path, filename.name)

    # Add the model
    with tempfile.TemporaryDirectory() as save_dir_str:
        dst = _copy_file(filename, Path(save_dir_str))
        if verbose:
            print(f"Pushing file to the {repo_id} repo on Hugging Face Hub")
        repo_url = upload_file(
            path_or_fileobj=str(dst),
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            commit_message=commit_message,
            token=token,
        )

    if verbose:
        print(
            f"Your file has been uploaded to the Hub, you can find it "
            f"here: {repo_url}"
        )
    return repo_url


def load_from_hub(
    repo_id: str,
    filename: str,
    token: Optional[str] = None,
) -> str:
    """
    Fetch one file from a Hugging Face Hub repository.

    The file is downloaded and cached on the local disk by
    ``hf_hub_download``; the local path it returns is passed back.

    Args:
        repo_id (str): id of the model repository from the Hugging Face Hub.
        filename (str): name of the file to download from the repository.
        token (str, optional): optional token for HF API.

    Returns:
        str: local path of the downloaded (cached) file.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        library_name="mbrl-lib",
        token=token,
    )


def download_folder_from_hub(
    repo_id: str,
    folder_name: str,
    save_dir: Union[str, Path],
    token: Optional[str] = None,
):
    """
    Download a folder from a repository.

    Every file under the top-level ``folder_name/`` directory of the repo is
    downloaded and then copied into ``save_dir``. Files are copied flat (by
    base name only), so sub-directory structure under ``folder_name`` is not
    preserved.

    Args:
        repo_id (str): id of the repository from the Hugging Face Hub.
        folder_name (str, Path): directory to download from the repository
        save_dir (str, Path): directory to save the folder
        token (str, optional): optional token for HF API.

    Raises:
        FileNotFoundError: if ``save_dir`` does not exist.
    """
    save_dir = Path(save_dir)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not save_dir.exists():
        raise FileNotFoundError(f"Directory {save_dir} does not exist.")

    # Match on the path prefix: a substring test would also pick up files
    # from unrelated folders such as "mymodel/..." or "agent/model/...".
    folder_prefix = f"{str(folder_name).strip('/')}/"
    files = list_repo_files(repo_id=repo_id, token=token)
    folder_files = [f for f in files if f.startswith(folder_prefix)]
    if not folder_files:
        print(f"No files found under '{folder_prefix}' in repository {repo_id}.")

    # Download everything first so a failed download leaves save_dir untouched.
    cached_files = [load_from_hub(repo_id, f, token) for f in folder_files]

    for cached_file in cached_files:
        _copy_file(Path(cached_file), save_dir)


def load_model_from_hub(
    repo_id: str, save_dir: Union[str, Path], token: Optional[str] = None
):
    """
    Download the ``model`` folder of a repository into a local directory.

    The local directory is created (with parents) when it does not exist.

    Args:
        repo_id (str): id of the repository from the Hugging Face Hub.
        save_dir (str, Path): directory to save the model files into.
        token (str, optional): optional token for HF API.
    """
    target_dir = Path(save_dir)
    if not target_dir.exists():
        print("Creating directory for saving model params.")
        target_dir.mkdir(parents=True, exist_ok=True)

    download_folder_from_hub(repo_id, "model", target_dir, token)


def load_agent_from_hub(
    repo_id: str, save_dir: Union[str, Path], token: Optional[str] = None
):
    """
    Download the ``agent`` folder of a repository into a local directory.

    The local directory is created (with parents) when it does not exist.

    Args:
        repo_id (str): id of the repository from the Hugging Face Hub.
        save_dir (str, Path): directory to save the agent params, etc.
        token (str, optional): optional token for HF API.
    """
    target_dir = Path(save_dir)
    if not target_dir.exists():
        print("Creating directory for saving agent params.")
        target_dir.mkdir(parents=True, exist_ok=True)

    download_folder_from_hub(repo_id, "agent", target_dir, token)


def create_agent_from_hub(repo_id: str, token: Optional[str] = None):
    """
    Create an agent from config file and weights in a repository.
    Currently works only for SAC agents

    The repository's ``config.yaml`` and ``agent/`` folder are downloaded to
    a temporary directory. The agent is instantiated with Hydra from
    ``cfg.algorithm.agent``, and its ``checkpoint.pth`` is loaded.

    Args:
        repo_id (str): id of the repository from the Hugging Face Hub.
        token (str, optional): optional token for HF API.

    Returns:
        SACAgent: the loaded agent wrapped in ``sac_wrapper.SACAgent``.

    Raises:
        RuntimeError: if the repository has no ``config.yaml``.
        ValueError: if ``algorithm.agent._target_`` is missing or is not the
            supported SAC implementation.
    """
    sac_target = "mbrl.third_party.pytorch_sac_pranz24.sac.SAC"

    # Pass the token so private repositories can be listed too.
    files = list_repo_files(repo_id=repo_id, token=token)
    if "config.yaml" not in files:
        raise RuntimeError(
            "Can't find the yaml file in HF repository to create agent."
        )

    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = Path(temp_dir_str)
        cached_file = load_from_hub(repo_id, "config.yaml", token)
        _copy_file(Path(cached_file), temp_dir)

        load_agent_from_hub(repo_id, temp_dir, token)

        cfg = OmegaConf.load(temp_dir / "config.yaml")

        # OmegaConf.select returns None for a missing key path instead of
        # raising, so a malformed config ends up in the ValueError below.
        agent_target = OmegaConf.select(cfg, "algorithm.agent._target_")
        if agent_target != sac_target:
            raise ValueError(
                f"Invalid agent configuration. (agent _target_: {agent_target!r})"
            )

        import mbrl.third_party.pytorch_sac_pranz24 as pytorch_sac

        agent: pytorch_sac.SAC = hydra.utils.instantiate(cfg.algorithm.agent)
        # Load while the temporary directory (and the checkpoint) still exists.
        agent.load_checkpoint(ckpt_path=temp_dir / "checkpoint.pth")
        return sac_wrapper.SACAgent(agent)
